#include <format>
#include <fstream>
#include <iostream>
#include <string>

// Keep assert() active in every build configuration; must precede <cassert>.
#undef NDEBUG

#include <cassert>
#include <locale>

// Stop Windows.h from defining the min/max macros; must precede the include.
#define NOMINMAX

#include "Windows.h"

/// Options collected from the command line; read by main() while generating code.
struct Config {
    /// Strategy the generated code uses to find the original exports.
    enum LocateMode {
        ByName  // resolve every export through GetProcAddress using its name
    };

    // DLL whose export table is read by the generator.
    std::string target_dll_path;
    // Path baked into the generated LoadLibraryA call; when empty, the generated
    // code builds a path under the system directory instead.
    std::string locate_dll_path;
    LocateMode locate_mode = ByName;
    // Destination file for the generated source; empty means "not specified".
    std::string output_path;
    // When true (and no output path is set), the output file is named after the DLL.
    bool use_dll_name_as_output_name{false};
};

/// Process-wide configuration instance.
Config config;

/// Escapes backslashes and double quotes so `text` can be placed between the
/// quotes of a C/C++ string literal in the generated source.
static std::string escape_for_string_literal(const std::string &text) {
    std::string escaped;
    escaped.reserve(text.size());
    for (const char ch : text) {
        if (ch == '\\' || ch == '"') {
            escaped += '\\';
        }
        escaped += ch;
    }
    return escaped;
}

/// Returns the part of `path` after the last '\\' or '/' (the whole string if
/// it contains no separator).
static std::string file_name_of(const std::string &path) {
    // find_last_of yields npos when nothing matches; npos + 1 wraps to 0.
    return path.substr(path.find_last_of("\\/") + 1);
}

/// Writes `text` to `path`; returns false if the file cannot be opened or written.
static bool write_text_file(const std::string &path, const std::string &text) {
    std::ofstream out(path);
    if (!out) {
        return false;
    }
    out << text;
    out.flush();
    return static_cast<bool>(out);
}

/// Validates the PE headers of the mapped image `lib` and returns its export
/// directory. Returns nullptr and sets `error` when the image is malformed,
/// has no export table, or exports nothing by name.
static const IMAGE_EXPORT_DIRECTORY *find_named_exports(HMODULE lib, std::string &error) {
    const auto *base = reinterpret_cast<const BYTE *>(lib);

    const auto *dos_header = reinterpret_cast<const IMAGE_DOS_HEADER *>(base);
    if (dos_header->e_magic != IMAGE_DOS_SIGNATURE) {
        error = "missing DOS signature";
        return nullptr;
    }

    const auto *nt_headers = reinterpret_cast<const IMAGE_NT_HEADERS *>(base + dos_header->e_lfanew);
    if (nt_headers->Signature != IMAGE_NT_SIGNATURE) {
        error = "missing NT signature";
        return nullptr;
    }

    const auto &optional_header = nt_headers->OptionalHeader;
    if (optional_header.NumberOfRvaAndSizes == 0) {
        error = "image has no data directories";
        return nullptr;
    }

    // An RVA of 0 means the image has no export table at all.
    const DWORD export_rva = optional_header.DataDirectory[IMAGE_DIRECTORY_ENTRY_EXPORT].VirtualAddress;
    if (export_rva == 0) {
        error = "image has no export table";
        return nullptr;
    }

    const auto *exports = reinterpret_cast<const IMAGE_EXPORT_DIRECTORY *>(base + export_rva);
    // Zero names would also make the generated function-pointer array zero-sized.
    if (exports->AddressOfNames == 0 || exports->NumberOfNames == 0) {
        error = "image exports no functions by name";
        return nullptr;
    }
    return exports;
}

/// Entry point: reads the export table of a target DLL and emits C++ source
/// that re-exports every named function through a table of function pointers.
///
/// Usage: <target_dll_path> [locate_dll_path] [-m locate_mode] [-o output_path] [-n]
/// The generated source goes to `-o output_path` if given, else to
/// `<dll name>.cc` when `-n` is set, else to standard output.
///
/// Returns 0 on success and 1 on a usage error, a DLL that cannot be loaded or
/// has no named exports, or an output file that cannot be written.
int main(int argc, char *args[]) {
    // NOTE: std::locale's constructor throws std::runtime_error if this locale
    // name is not available on the system.
    std::locale::global(std::locale("en_US.UTF-8"));

    // Parse arguments
    if (argc == 1) {
        std::cerr << "Usage: " << args[0]
                  << " <target_dll_path> [locate_dll_path] [-m locate_mode] [-o output_path] [-n: use dll name as output name]\n";
        return 1;
    }

    // Tracks whether the next argument is the value of a preceding switch.
    enum class ParseState {
        Unknown,
        LocateModeSwitch,
        OutputPathSwitch,
    } state = ParseState::Unknown;

    for (int i = 1; i < argc; i++) {
        const std::string arg = args[i];
        switch (state) {
            case ParseState::Unknown:
                if (arg == "-m") {
                    state = ParseState::LocateModeSwitch;
                } else if (arg == "-o") {
                    state = ParseState::OutputPathSwitch;
                } else if (arg == "-n") {
                    config.use_dll_name_as_output_name = true;
                } else if (config.target_dll_path.empty()) {
                    config.target_dll_path = arg;
                } else if (config.locate_dll_path.empty()) {
                    config.locate_dll_path = arg;
                } else {
                    std::cerr << "Unexpected argument: " << arg << "\n";
                    return 1;
                }
                break;
            case ParseState::LocateModeSwitch:
                if (arg == "name") {
                    config.locate_mode = Config::LocateMode::ByName;
                } else {
                    std::cerr << "Unknown locate mode: " << arg << "\n";
                    return 1;
                }
                state = ParseState::Unknown;
                break;
            case ParseState::OutputPathSwitch:
                // Previously unhandled: "-o <path>" was rejected as an unexpected argument.
                config.output_path = arg;
                state = ParseState::Unknown;
                break;
        }
    }

    // A switch that expects a value was the last argument.
    if (state != ParseState::Unknown) {
        std::cerr << "Missing value for option: " << args[argc - 1] << "\n";
        return 1;
    }

    if (config.target_dll_path.empty()) {
        config.target_dll_path = "version.dll";
    }

    // Map the DLL without resolving its imports; only its headers are inspected.
    HMODULE lib = LoadLibraryExA(config.target_dll_path.c_str(), nullptr, DONT_RESOLVE_DLL_REFERENCES);
    if (lib == nullptr) {
        std::cerr << "Failed to load " << config.target_dll_path
                  << " (error " << GetLastError() << ")\n";
        return 1;
    }

    std::string pe_error;
    const IMAGE_EXPORT_DIRECTORY *exports = find_named_exports(lib, pe_error);
    if (exports == nullptr) {
        std::cerr << "Cannot read exports of " << config.target_dll_path << ": " << pe_error << "\n";
        FreeLibrary(lib);
        return 1;
    }

    const auto *base = reinterpret_cast<const BYTE *>(lib);
    // AddressOfNames points at an array of RVAs, one per exported name.
    const auto *name_rvas = reinterpret_cast<const DWORD *>(base + exports->AddressOfNames);
    const DWORD name_count = exports->NumberOfNames;
    const auto export_name = [&](DWORD index) {
        return std::string(reinterpret_cast<const char *>(base + name_rvas[index]));
    };

    std::string code = R"(
// This file is generated by blook
// Neither touch/modify nor include this file in your project!

#include "Windows.h"

)";

    code += "// blook: USE_STATIC_INITIALIZATION_TO_INIT_EXPORTS";
    if (config.locate_mode == Config::LocateMode::ByName) {
        code += std::format(R"(
void (*blookHijackFuncs[{0}])() = {{}};
const char* blookHijackFuncNames[] = {{
)", name_count);
        for (DWORD i = 0; i < name_count; i++) {
            code += std::format(" \"{0}\",\n", escape_for_string_literal(export_name(i)));
        }

        code += R"(};

int blookInit = ([] {)";

        if (config.locate_dll_path.empty()) {
            // Only the file name is appended to the system directory: the target
            // may have been given with a directory, which would yield a bogus path.
            code += std::format(R"(
    char system_dir_path[MAX_PATH];
    GetSystemDirectoryA(system_dir_path, MAX_PATH);
    strcat_s(system_dir_path, "\\{}");
    HMODULE lib = LoadLibraryA(system_dir_path);
)", escape_for_string_literal(file_name_of(config.target_dll_path)));
        } else {
            // Escape the path so backslashes survive inside the generated literal.
            code += std::format(R"(
    HMODULE lib = LoadLibraryA("{0}");
)", escape_for_string_literal(config.locate_dll_path));
        }

        code += R"(
    for (int i = 0; i < sizeof(blookHijackFuncs) / sizeof(blookHijackFuncs[0]); i++) {
        blookHijackFuncs[i] = (void (*)()) GetProcAddress(lib, blookHijackFuncNames[i]);
    }

    return 0;
})();

)";
    } else {
        code += std::format(R"(
void (*blookHijackFuncs[{}])() = {{}};)", name_count);

    }

    code += "// blook: GENERATED_PLACEHOLDER_FUNCTION_EXPORTS";
    for (DWORD i = 0; i < name_count; i++) {
        // Placeholder functions are numbered from 1; the pointer table is 0-based.
        code += std::format(R"(
#pragma comment(linker, "/EXPORT:{1}=BLOOK_PLACEHOLDER_{0}")
extern "C" void BLOOK_PLACEHOLDER_{0}() {{ return (blookHijackFuncs[{2}])(); }}
)", i + 1, export_name(i), i);
    }

    // All export names have been copied into `code`; the mapping is no longer needed.
    FreeLibrary(lib);

    if (!config.output_path.empty()) {
        if (!write_text_file(config.output_path, code)) {
            std::cerr << "Failed to write " << config.output_path << "\n";
            return 1;
        }
    } else if (config.use_dll_name_as_output_name) {
        // "<dir>/<name>.<ext>" -> "<name>.cc"
        std::string outputName = file_name_of(config.target_dll_path);
        outputName = outputName.substr(0, outputName.find_last_of('.'));
        outputName += ".cc";
        if (!write_text_file(outputName, code)) {
            std::cerr << "Failed to write " << outputName << "\n";
            return 1;
        }
    } else {
        std::cout << code;
    }

    return 0;
}